오늘 풀어본 문제는 백준의 1208번 문제1이다. 문제 풀이에 사용한 언어는 C++ 이다.
이 문제의 내용과 조건은 다음과 같다.
$N$ 개의 정수로 이루어진 수열이 있을 때, 크기가 양수인 부분수열 중에서 그 수열의 원소를 다 더한 값이 $S$ 가 되는 경우의 수를 구하는 프로그램을 작성하시오.
첫째 줄에 정수의 개수를 나타내는 $N$ 과 정수 $S$ 가 주어진다. $(1 \le N \le 40, \vert S \vert \le 1,000,000)$ 둘째 줄에 $N$ 개의 정수가 빈 칸을 사이에 두고 주어진다. 주어지는 정수의 절댓값은 $100,000$ 을 넘지 않는다.
첫째 줄에 합이 $S$ 가 되는 부분수열의 개수를 출력한다.
이 문제를 해결하기 위해서는 MITM (Meet In The Middle) 이라는 방법을 이용해야 한다. MITM은 브루트 포스 방식을 사용할 때 소요 시간을 조금 더 시간을 줄일 수 있게 도와주는 기법이다.
코드는 다음과 같이 작성하였다.
#include <bits/stdc++.h>
using namespace std;
int getSum(const vector<int>& arr, int bitInfo);
int main(void) {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
vector<int> a, b;
map<int, int> half;
int N, S, result = 0;
cin >> N >> S;
for (int i=0; i<(N/2); i++) {
int temp;
cin >> temp;
a.emplace_back(temp);
}
for (int i=(N/2); i<N; i++) {
int temp;
cin >> temp;
b.emplace_back(temp);
}
for (int i = 0; i < (1 << (N/2)); i++) {
int sum = getSum(a, i);
if (half.find(sum) == half.end()) {
half[sum] = 1;
}
else {
half[sum]++;
}
}
for (int i = 0; i < (1 << (N-(N/2))); i++) {
int sum = getSum(b, i);
result += half[S - sum];
}
if (S == 0) {
result--;
}
cout << result;
return 0;
}
int getSum(const vector<int>& arr, int bitInfo) {
int result = 0;
for (int i = 0; bitInfo > 0; i++) {
if (bitInfo % 2 == 1) {
result += arr[i];
}
bitInfo >>= 1;
}
return result;
}
실행 결과 ‘틀렸습니다’ 가 떴다.
틀린 이유를 찾으러 질문게시판을 돌아다닌 결과, 결과가 int
의 범위를 벗어날 수 있다는 것을 알게되었다. 그래서 long long int
로 바꾸는 것을 시도해 보았다.
코드는 다음과 같이 수정하였다.
#include <bits/stdc++.h>
using namespace std;
int getSum(const vector<int>& arr, int bitInfo);
int main(void) {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
vector<int> a, b;
map<int, int> half;
int N, S;
long long int result = 0;
cin >> N >> S;
for (int i=0; i<(N/2); i++) {
int temp;
cin >> temp;
a.emplace_back(temp);
}
for (int i=(N/2); i<N; i++) {
int temp;
cin >> temp;
b.emplace_back(temp);
}
for (int i = 0; i < (1 << (N/2)); i++) {
int sum = getSum(a, i);
if (half.find(sum) == half.end()) {
half[sum] = 1;
}
else {
half[sum]++;
}
}
for (int i = 0; i < (1 << (N-(N/2))); i++) {
int sum = getSum(b, i);
result += half[S - sum];
}
if (S == 0) {
result--;
}
cout << result;
return 0;
}
int getSum(const vector<int>& arr, int bitInfo) {
int result = 0;
for (int i = 0; bitInfo > 0; i++) {
if (bitInfo % 2 == 1) {
result += arr[i];
}
bitInfo >>= 1;
}
return result;
}
실행 결과 ‘시간 초과’ 가 떴다.
여전히 시간 초과가 발생하길래 뭐가 문제일까 질문게시판을 돌아다녔고, map
의 값의 자료형을 long long int
로 해야된다는 식의 내용을 보고 한번 바꾸어 제출해보기로 했다.
코드는 다음과 같이 수정하였다.
#include <bits/stdc++.h>
using namespace std;
int getSum(const vector<int>& arr, int bitInfo);
int main(void) {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
vector<int> a, b;
map<int, long long int> half;
int N, S;
long long int result = 0;
cin >> N >> S;
for (int i=0; i<(N/2); i++) {
int temp;
cin >> temp;
a.emplace_back(temp);
}
for (int i=(N/2); i<N; i++) {
int temp;
cin >> temp;
b.emplace_back(temp);
}
for (int i = 0; i < (1 << (N/2)); i++) {
int sum = getSum(a, i);
if (half.find(sum) == half.end()) {
half[sum] = 1;
}
else {
half[sum]++;
}
}
for (int i = 0; i < (1 << (N-(N/2))); i++) {
int sum = getSum(b, i);
result += half[S - sum];
}
if (S == 0) {
result--;
}
cout << result;
return 0;
}
int getSum(const vector<int>& arr, int bitInfo) {
int result = 0;
for (int i = 0; bitInfo > 0; i++) {
if (bitInfo % 2 == 1) {
result += arr[i];
}
bitInfo >>= 1;
}
return result;
}
실행 결과 ‘시간 초과’ 가 떴다.
결국 최후의 수단을 쓰기로 했다. map
대신 unordered_map
을 사용하기로 했다. map
은 Self-balanced 이진탐색트리로 구현되지만, unordered_map
은 해시테이블로 구현되기 때문에 시간을 더 적게 소요할 것이라고 생각했다.
코드는 다음과 같이 수정하였다.
#include <bits/stdc++.h>
#include <unordered_map>
using namespace std;
int getSum(const vector<int>& arr, int bitInfo);
int main(void) {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
vector<int> a, b;
unordered_map<int, long long int> half;
int N, S;
long long int result = 0;
cin >> N >> S;
for (int i=0; i<(N/2); i++) {
int temp;
cin >> temp;
a.emplace_back(temp);
}
for (int i=(N/2); i<N; i++) {
int temp;
cin >> temp;
b.emplace_back(temp);
}
for (int i = 0; i < (1 << (N/2)); i++) {
int sum = getSum(a, i);
if (half.find(sum) == half.end()) {
half[sum] = 1;
}
else {
half[sum]++;
}
}
for (int i = 0; i < (1 << (N-(N/2))); i++) {
int sum = getSum(b, i);
result += half[S - sum];
}
if (S == 0) {
result--;
}
cout << result;
return 0;
}
int getSum(const vector<int>& arr, int bitInfo) {
int result = 0;
for (int i = 0; bitInfo > 0; i++) {
if (bitInfo % 2 == 1) {
result += arr[i];
}
bitInfo >>= 1;
}
return result;
}
그러자 모든 테스트 케이스를 통과하고 정답이 나오는 것을 확인할 수 있었다.
시간 초과 문제를 unordered_map
을 이용하여 해결하는 건 항상 기분이 상쾌하다. 어떤 사람들은 최악의 경우 $O(n)$ 이라며 쓰지 말라고 하는 사람도 있지만, 내 마음 속에선 언제나 $O(1)$ 이다. 해시테이블 최고!!!
오늘의 PS는 여기까지!
1: https://www.acmicpc.net/problem/1208